"""
#
# Flow stability for dynamic community detection https://arxiv.org/abs/2101.06131v2
#
# Copyright (C) 2021 Alexandre Bovet <alexandre.bovet@maths.ox.ac.uk>
#
# This program is free software; you can redistribute it and/or modify it under
# the terms of the GNU Lesser General Public License as published by the Free
# Software Foundation; either version 3 of the License, or (at your option) any
# later version.
#
# This program is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
# details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.


Create the synthetic example with a continuous change (Fig. 4) and cluster it
using the flow stability, a static clustering and the multilayer modularity.
    

"""
import sys
import os
PACKAGE_PARENT = '..'
SCRIPT_DIR = os.path.dirname(os.path.realpath(os.path.join(os.getcwd(), os.path.expanduser(__file__))))
sys.path.append(os.path.normpath(os.path.join(SCRIPT_DIR, PACKAGE_PARENT)))

import numpy as np
from SynthTempNetwork import Individual, SynthTempNetwork
from TemporalNetwork import ContTempNetwork
import pickle
import time
import matplotlib.pyplot as plt
from TemporalStability import static_clustering
import igraph as ig
    
import leidenalg as la

from TemporalStability import FlowIntegralClustering, avg_norm_var_information
import random

from scipy.linalg import expm
from scipy import integrate
from TemporalStability import Clustering

def Round_To_n(x, n):
    """
    Round `x` relative to its order of magnitude.

    The number of decimals given to `round` is ``n - floor(log10(|x|))``,
    i.e. `n` digits are kept after the leading significant digit
    (``n + 1`` significant digits in total).

    Parameters
    ----------
    x : float
        Value to round. May be negative or zero.
    n : int
        Number of digits kept after the leading significant digit.

    Returns
    -------
    float
        The rounded value; `x` itself is returned unchanged when it is zero.
    """
    if x == 0:
        # log10(0) is -inf and cannot be turned into a number of decimals.
        return x

    # Only the magnitude of x determines where to round; multiplying the
    # exponent by the sign of x would round negative values at the wrong digit.
    exponent = int(np.floor(np.log10(abs(x))))
    return round(x, n - exponent)

# NOTE(review): unconditional abort right after the imports and helper
# definition, so executing the file top-to-bottom stops here. Presumably a
# guard so that the `#%%` cells below are run one at a time in an interactive
# session -- confirm before removing.
raise Exception
#%%



# Base name shared by every file this script writes.
filename = 'synthtemp_toymod_comp'

# Number of simulated networks generated and analysed below.
num_sims = 10

# Output folders (relative to the current working directory) for the pickled
# results and for the figures.
savedir, figdir = ('../paper_data/synthtempnet/toy_model_comparison',
                   '../figures/synthtemp/toy_model_comparison')

# Make sure both folders exist; nothing happens if they are already there.
os.makedirs(savedir, exist_ok=True)
os.makedirs(figdir, exist_ok=True)

#%%


# start with two clusters: (0,1,2,3) & (4,5,6,7)
# continuous change to :   (0,1,6,7) & (2,3,4,5)

def make_step_block_probs_linear_change(t_end, t0=0, p=1):
    """
    Build a time-dependent 4x4 block probability matrix.

    Two constant matrices are defined for four groups:

    * ``B0`` : groups paired as (0, 1) and (2, 3),
    * ``B1`` : groups paired as (0, 3) and (1, 2).

    In both, entries on the diagonal and between paired groups are `p`, all
    other entries are ``p0 = 1/2 - p``.

    Parameters
    ----------
    t_end : float
        Time at which the interpolation reaches ``B1``.
    t0 : float, optional
        Time at which the interpolation starts from ``B0``. Default is 0.
    p : float, optional
        Probability within a pair of groups. Must satisfy ``p <= 1/2``,
        otherwise the returned function fails its assertion when called.
        Note that the default value (1) does not satisfy this, so `p` has
        to be given explicitly.

    Returns
    -------
    block_mod_func : callable
        ``block_mod_func(t)`` returns a new 4x4 array: ``B0`` for ``t < t0``,
        the linear interpolation between ``B0`` and ``B1`` for
        ``t0 <= t <= t_end``, and ``B1`` (after printing a warning) for
        ``t > t_end``. If ``t_end == t0`` the change is a step at ``t0``.
    """
    p0 = 1/2-p

    def paired_block_matrix(pairs):
        # p on the diagonal and between the two groups of each pair,
        # p0 everywhere else
        B = np.ones((4,4))*p0
        np.fill_diagonal(B,p)
        for a, b in pairs:
            B[a,b] = p
            B[b,a] = p
        return B

    # Both end points are independent of t: build them once instead of at
    # every call of the returned function.
    B0 = paired_block_matrix([(0,1), (2,3)])
    B1 = paired_block_matrix([(0,3), (1,2)])

    duration = t_end - t0

    def block_mod_func(t):
        assert p0 >= 0

        if t < t0:
            # copies are returned so that callers cannot alter B0/B1
            return B0.copy()

        if t <= t_end:
            if duration == 0:
                # zero-length interval: avoid 0/0 and switch directly to B1
                return B1.copy()

            # interpolate between B0 and B1 between t0 and t_end
            elapsed = t - t0
            return ((duration-elapsed)/duration)*B0 + (elapsed/duration)*B1

        print('Warning : t must be >=0 and <= t_end, ' +\
              't is ', t)
        return B1.copy()

    return block_mod_func


# Time scales: `inter_tau` is passed to every Individual as
# `inter_distro_scale`; `activ_tau` sets the length of the simulation.
inter_tau, activ_tau = 1, 1

# Starting time of the simulation.
t_start = 0

# Population layout: 4 groups of 2 individuals each.
n_groups, n_per_group = 4, 2



def make_activ_tau_linear_change(t_end, t0=0, tau_start=1, tau_end=2):
    """
    Build a function giving a linearly changing activation time scale.

    Parameters
    ----------
    t_end : float
        Time at which the time scale reaches `tau_end`.
    t0 : float, optional
        Time at which the time scale equals `tau_start`. Default is 0.
    tau_start : float, optional
        Value of the time scale at ``t = t0``. Default is 1.
    tau_end : float, optional
        Value of the time scale at ``t = t_end``. Default is 2.

    Returns
    -------
    callable
        Function of `t` returning the tuple ``(0, tau(t))`` where ``tau`` is
        the straight line through ``(t0, tau_start)`` and ``(t_end, tau_end)``.
        The line is not clipped outside ``[t0, t_end]``.
        NOTE(review): the tuple is presumably (loc, scale) of the activation
        time distribution used by `Individual` -- confirm.
    """
    # The elapsed time must be measured from t0, otherwise the ramp only
    # passes through tau_start/tau_end when t0 == 0.
    return lambda t: (0, tau_start + (t-t0)*(tau_end-tau_start)/(t_end-t0))

# Probability of interaction inside a pair of groups.
# 0.5 : no noise
p = 0.5
# with noise
# p=0.5*0.95

# The simulation lasts 100 times the activation time scale.
t_end = 100*(activ_tau)

# Activation time scale of the first two groups goes 1 -> 2,
# the one of the last two groups goes 2 -> 1.
activ_tau_start_1, activ_tau_stop_1 = 1, 2
activ_tau_start_2, activ_tau_stop_2 = 2, 1

block_prob_mod_func = make_step_block_probs_linear_change(t_end,p=p)

# One Individual per node; nodes g*n_per_group ... (g+1)*n_per_group - 1
# belong to group g and get their own activation function.
individuals = []
for g, tau_start, tau_end in zip(range(n_groups),
                                 2*[activ_tau_start_1] + 2*[activ_tau_start_2],
                                 2*[activ_tau_stop_1] + 2*[activ_tau_stop_2]):

    individuals += [Individual(i,
                               inter_distro_scale=inter_tau,
                               activ_distro_mod_func=make_activ_tau_linear_change(t_end,0,tau_start,tau_end),
                               group=g)
                    for i in range(g*n_per_group, g*n_per_group + n_per_group)]


#%%
# Store the simulation parameters next to the simulated networks so that the
# analysis cells can reload them later.
with open(os.path.join(savedir, filename + '_sim_param.pickle'), 'wb') as fopen:
    pickle.dump(dict(inter_tau=inter_tau,
                     activ_tau=activ_tau,
                     t_start=t_start,
                     t_end=t_end,
                     n_groups=n_groups,
                     n_per_group=n_per_group,
                     p=p,
                     activ_tau_start_1=activ_tau_start_1,
                     activ_tau_stop_1=activ_tau_stop_1,
                     activ_tau_stop_2=activ_tau_stop_2,
                     activ_tau_start_2=activ_tau_start_2),
                fopen)
#%% run sim    
# Generate `num_sims` realisations of the synthetic temporal network and save
# each of them in `savedir` as '<filename><ii>.pickle'.
for i in range(num_sims):
    # A new simulator is created for every realisation; the `individuals`
    # list and the block probability function are the same objects each time.
    # NOTE(review): if `Individual` instances keep state between runs, the
    # realisations are not independent -- confirm.
    sim = SynthTempNetwork(individuals=individuals, t_start=t_start, t_end=t_end,
                           next_event_method='block_probs_mod',
                           block_prob_mod_func=block_prob_mod_func)
    
    print('running simulation')
    # wall-clock timing of the run
    t0 = time.time()
    sim.run(save_all_states=True, save_dt_states=True, verbose=False)
    print('done in ', time.time()-t0)
    
    # Convert the simulated events (sources, targets, start and end times)
    # into a continuous-time temporal network.
    net = ContTempNetwork(source_nodes=sim.indiv_sources,
                          target_nodes=sim.indiv_targets,
                          starting_times=sim.start_times,
                          ending_times=sim.end_times,
                          merge_overlapping_events=True)
        
    # Bare expression: the result is discarded when the file runs as a
    # script and is only displayed in an interactive console.
    net.events_table.head(10)
    
    


    # zero-padded two-digit simulation index in the file name
    net.save(os.path.join(savedir, filename + f'{i:02d}' + '.pickle'))
    
        
#%% load networks        
    
# Read back the simulation parameters saved by this script.
with open(os.path.join(savedir, filename + '_sim_param.pickle'), 'rb') as fopen:
# with open(os.path.join(savedir, filename + '_sim_param_noisy.pickle'), 'rb') as fopen:
    sim_param = pickle.load(fopen)

# Reload the simulated networks, in the order of their simulation index.
nets = [ContTempNetwork.load(os.path.join(savedir, filename + f'{i:02d}' + '.pickle'))
        for i in range(num_sims)]
    
    
    
#%% clustering


# Number of repetitions of the clustering optimisation for each parameter value
num_repeat = 50

#%% flow stab clustering

# Results are stored as result[sim][lamda] -> list with one entry per
# repetition: the cluster lists and the corresponding stability values,
# for the forward and the time-reversed clusterings.
clusters_forw = {}
clusters_back = {}
stab_forw = {}
stab_back = {}

for sim in range(0,num_sims):
    
    
    clusters_forw[sim] = {}
    clusters_back[sim] = {}
    stab_forw[sim] = {}
    stab_back[sim] = {}
    
    # 10 log-spaced values of lamda between 0.005 and 0.1, rounded with
    # Round_To_n so that they can be used as dictionary keys
    for lamda in map(lambda x: Round_To_n(x,3),
                      np.logspace(np.log10(0.005), np.log10(0.1),num=10)):
    # for lamda in [0.08]:        
    
        clusters_forw[sim][lamda] = []
        clusters_back[sim][lamda] = []
        stab_forw[sim][lamda] = []
        stab_back[sim][lamda] = []
        
        print(sim,lamda)
        # fills nets[sim].inter_T[lamda], which is read below
        nets[sim].compute_inter_transition_matrices(lamda=lamda, verbose=True)
        
        for i in range(num_repeat):
            print(i)
            
            # A new clustering object is built from the same (densified)
            # transition matrices at every repetition.
            # NOTE(review): presumably so that each Louvain run starts from
            # scratch; building it once per lamda would be cheaper if the
            # object can be reused -- confirm.
            fclust_forw = FlowIntegralClustering(time_list=nets[sim].times,
                                            # T_inter_list=[T.toarray() for T in nets[sim].inter_T[lamda]])
                                            T_inter_list=[T.toarray() for T in nets[sim].inter_T[lamda]])
                                            
            
            fclust_forw.find_louvain_clustering()
            # only the first entry of `partition` / `clustering` is kept
            clusters_forw[sim][lamda].append(fclust_forw.partition[0].cluster_list)
            stab_forw[sim][lamda].append(fclust_forw.clustering[0].compute_stability())
            
            # same procedure with the time direction reversed
            fclust_back = FlowIntegralClustering(time_list=nets[sim].times,
                                            T_inter_list=[T.toarray() for T in nets[sim].inter_T[lamda]],
                                            reverse_time=True)
                                            
            fclust_back.find_louvain_clustering()

            clusters_back[sim][lamda].append(fclust_back.partition[0].cluster_list)
            stab_back[sim][lamda].append(fclust_back.clustering[0].compute_stability())
    
    
#%%
# Summaries of the flow stability runs, one list per simulation with one
# entry per value of lamda (in increasing order):
#   best_clusts_* : clustering with the largest stability among the repetitions
#   num_clusts_*  : number of clusters of that best clustering
#   avg_nmi_*     : avg_norm_var_information over all repetitions
best_clusts_forw, best_clusts_back = {}, {}
num_clusts_forw, num_clusts_back = {}, {}
avg_nmi_forw, avg_nmi_back = {}, {}

for sim in range(num_sims):
    lambdas = sorted(stab_forw[sim])

    best_clusts_forw[sim] = [clusters_forw[sim][lam][np.argmax(stab_forw[sim][lam])]
                             for lam in lambdas]
    best_clusts_back[sim] = [clusters_back[sim][lam][np.argmax(stab_back[sim][lam])]
                             for lam in lambdas]

    num_clusts_forw[sim] = list(map(len, best_clusts_forw[sim]))
    num_clusts_back[sim] = list(map(len, best_clusts_back[sim]))

    avg_nmi_forw[sim] = [avg_norm_var_information(clusters_forw[sim][lam])
                         for lam in lambdas]
    avg_nmi_back[sim] = [avg_norm_var_information(clusters_back[sim][lam])
                         for lam in lambdas]

#%%
# `lambdas` is the list computed for the last simulation of the loop above.
with open(os.path.join(savedir, filename + '_flow_stab_res.pickle'), 'wb') as fopen:
    pickle.dump(dict(clusters_forw=clusters_forw,
                     clusters_back=clusters_back,
                     stab_forw=stab_forw,
                     stab_back=stab_back,
                     best_clusts_forw=best_clusts_forw,
                     best_clusts_back=best_clusts_back,
                     num_clusts_forw=num_clusts_forw,
                     num_clusts_back=num_clusts_back,
                     avg_nmi_forw=avg_nmi_forw,
                     avg_nmi_back=avg_nmi_back,
                     lambdas=lambdas),
                fopen)

#%% static clustering


# Static clustering of the aggregated network.
# Only the first simulated network (nets[0]) is used here.
A = nets[0].compute_static_adjacency_matrix()
plt.matshow(A.toarray())

# clusters_stat[t] / stab_stat[t] -> one entry per repetition
clusters_stat = {}
stab_stat = {}

# 10 log-spaced values of t between 0.1 and 2, rounded with Round_To_n so
# that they can be used as dictionary keys
for t in map(lambda x: Round_To_n(x,3),
                 np.logspace(np.log10(0.1), np.log10(2),num=10)):
    
    clusters_stat[t] = []
    stab_stat[t] = []
    for i in range(num_repeat):
        # a new clustering object is created for each repetition
        stat_clust = static_clustering(A,t=t,linearized=True)
        stat_clust.find_louvain_clustering()
        clusters_stat[t].append(stat_clust.partition.cluster_list)
        stab_stat[t].append(stat_clust.compute_stability())
        
#%%

# Values of t used for the static clustering, in increasing order.
ts = sorted(stab_stat)

# For each t: clustering with the largest stability among the repetitions,
# its number of clusters, and avg_norm_var_information over all repetitions.
best_clusts_stat = [clusters_stat[t_val][np.argmax(stab_stat[t_val])]
                    for t_val in ts]
num_clusts_stat = list(map(len, best_clusts_stat))
avg_nmi_stat = [avg_norm_var_information(clusters_stat[t_val]) for t_val in ts]

with open(os.path.join(savedir, filename + '_static_res.pickle'), 'wb') as fopen:
    pickle.dump(dict(clusters_stat=clusters_stat,
                     stab_stat=stab_stat,
                     best_clusts_stat=best_clusts_stat,
                     num_clusts_stat=num_clusts_stat,
                     avg_nmi_stat=avg_nmi_stat),
                fopen)



#%% multislice
# Time window of the simulation, read from the saved parameters.
t0, tend = sim_param['t_start'], sim_param['t_end']

# The window is cut into `num_slices` slices of equal duration.
num_slices = 5
time_slices = np.linspace(sim_param['t_start'], sim_param['t_end'], num_slices+1)
t_starts, t_stops = time_slices[:-1], time_slices[1:]

# adjacencies[sim] : dense adjacency matrix of each non-empty time slice
# graphs[sim]      : the corresponding undirected weighted igraph graphs
adjacencies = {}
graphs = {}
for sim in range(num_sims):

    # NOTE(review): slices without any interaction are dropped, so a
    # simulation can end up with fewer than `num_slices` layers, while later
    # cells assume len(time_slices)-1 layers -- confirm this cannot happen.
    adjacencies[sim] = []
    for slice_start, slice_stop in zip(t_starts, t_stops):
        slice_adj = nets[sim].compute_static_adjacency_matrix(start_time=slice_start,
                                                              end_time=slice_stop).toarray()
        if (slice_adj != 0).any():
            adjacencies[sim].append(slice_adj)

    graphs[sim] = []
    for slice_adj in adjacencies[sim]:
        n_rows, n_cols = slice_adj.shape

        layer = ig.Graph()
        layer.add_vertices(n_rows)
        # the same 'id' in every slice identifies a node across slices
        layer.vs['id'] = list(range(n_rows))

        # one edge per strictly positive entry of the upper triangle
        for u in range(n_rows):
            for v in range(u+1, n_cols):
                if slice_adj[u,v] > 0:
                    layer.add_edge(u, v, weight=slice_adj[u,v])

        graphs[sim].append(layer)
    
    
    

#%%
# Number of repetitions of the optimisation for each resolution parameter
num_repeat = 50

# Results are stored as result[sim][resolution_parameter] -> list with one
# entry per repetition.
clusters_multi = {}
stab_multi = {}

for sim in range(num_sims):
    # mean of the per-slice mean positive weight
    avg_weight = np.mean([A[A>0].mean() for A in adjacencies[sim]])

    # coupling between slices: 10% of the average positive weight.
    # After the loop this variable holds the value of the last simulation.
    interslice_weight = avg_weight*0.1
    
    
    layers, interslice_layer, G_full = \
                la.time_slices_to_layers(graphs[sim],
                                interslice_weight=interslice_weight)
# NOTE(review): the cell marker below sits inside the body of the `for sim`
# loop; running the two halves as separate cells would split the loop.
#%%            

    clusters_multi[sim] = {}
    stab_multi[sim] = {}

    # 10 log-spaced resolution parameters between 0.1 and 5, rounded with
    # Round_To_n so that they can be used as dictionary keys
    for resolution_parameter in map(lambda x: Round_To_n(x,3),
                     np.logspace(np.log10(0.1), np.log10(5),num=10)):
        
        print(resolution_parameter)
        # if resolution_parameter not in res.keys():
        clusters_multi[sim][resolution_parameter] = []
        stab_multi[sim][resolution_parameter] = []
        
        for i in range(num_repeat):
            print(i)
            optimiser = la.Optimiser()
                
            # new seed from the OS entropy source (3 random bytes) for each run
            optimiser.set_rng_seed(int.from_bytes(os.urandom(3), byteorder="big"))
                           
            # one partition per layer, all with the same resolution parameter
            partitions = [la.RBConfigurationVertexPartition(H,
                                                       weights='weight',
                                      resolution_parameter=resolution_parameter)
                                for H in layers];               
            
            # the interslice layer is given a resolution parameter of 0
            interslice_partition = \
                           la.RBConfigurationVertexPartition(interslice_layer, 
                                                             resolution_parameter=0,
                                                             # node_sizes='node_size',
                                                             weights='weight')
                           
                           
                           
            # the returned value (`diff`) is not used afterwards
            diff = optimiser.optimise_partition_multiplex(partitions + [interslice_partition],
                                                          n_iterations=2)            
                           
            # Only the membership of the first partition is stored.
            # NOTE(review): per the leidenalg documentation the partitions
            # optimised together share one membership -- confirm.
            clusters_multi[sim][resolution_parameter].append(partitions[0].membership)
            # quality of the run: sum of the `modularity` attribute of all
            # partitions, including the interslice one
            stab_multi[sim][resolution_parameter].append(sum([p.modularity for p in partitions + [interslice_partition]]))
            
        
#%%
# Imported once, before the loop, instead of at every iteration. A later cell
# of this script also uses `Partition`, so the name must be bound even if the
# loop below does not run.
from TemporalStability import Partition

# Summaries of the multilayer runs, one list per simulation with one entry
# per resolution parameter (in increasing order):
#   res_params        : the resolution parameters
#   best_clusts_multi : membership with the largest quality among repetitions
#   num_clusts_multi  : number of clusters of that membership (largest label + 1)
#   avg_nmi_multi     : avg_norm_var_information over all repetitions
res_params = {}
best_clusts_multi = {}
num_clusts_multi = {}
avg_nmi_multi = {}

# Number of (slice, node) pairs, taken from the first network.
# NOTE(review): assumes that no empty time slice was dropped when the layers
# were built and that all networks have the same number of nodes -- confirm.
num_multilayer_nodes = (len(time_slices)-1)*nets[0].num_nodes

for sim in range(num_sims):
    res_params[sim] = sorted(stab_multi[sim])
    best_clusts_multi[sim] = [clusters_multi[sim][res][np.argmax(stab_multi[sim][res])]
                              for res in res_params[sim]]

    num_clusts_multi[sim] = [max(c)+1 for c in best_clusts_multi[sim]]

    # convert every membership list into a list of clusters
    parts_multi = {}
    for res in res_params[sim]:
        parts_multi[res] = [Partition(num_nodes=num_multilayer_nodes,
                                      node_to_cluster_dict=dict(enumerate(clusts))).cluster_list
                            for clusts in clusters_multi[sim][res]]

    avg_nmi_multi[sim] = [avg_norm_var_information(parts_multi[res])
                          for res in res_params[sim]]

#%%    
# with open(os.path.join(savedir, filename + '_multislice_res.pickle'), 'wb') as fopen:
# NOTE(review): `interslice_weight` is the value left by the last simulation
# of the clustering loop; it is used both in the file name and in the saved
# dictionary although it is computed separately for every simulation.
with open(os.path.join(savedir, filename + f'_multislice_res_num_slice{num_slices}_intersweight{interslice_weight:.3f}.pickle'), 'wb') as fopen:
    pickle.dump(dict(clusters_multi=clusters_multi,
                     stab_multi=stab_multi,
                     best_clusts_multi=best_clusts_multi,
                     num_clusts_multi=num_clusts_multi,
                     avg_nmi_multi=avg_nmi_multi,
                     res_params=res_params,
                     interslice_weight=interslice_weight,
                     num_slices=num_slices),
                fopen)
    
#%%
# One curve per simulation: avg_norm_var_information between repetitions
# as a function of the resolution parameter (log x axis).
plt.figure()
for sim in range(num_sims):


    plt.plot(res_params[sim], avg_nmi_multi[sim], 'o-')
    plt.xscale('log')
    
# One curve per simulation: number of clusters of the best partition.
plt.figure()
for sim in range(num_sims):


    plt.plot(res_params[sim], num_clusts_multi[sim], 'o-')
    plt.xscale('log')
    
#%%
# Mean and standard deviation over the simulations.
# NOTE(review): outside of the inner comprehensions, `sim` is the value left
# by the last `for sim` loop executed, so the x values and the number of
# points are those of that simulation -- this assumes all simulations share
# the same resolution parameters.
plt.figure()


plt.errorbar(res_params[sim], [np.mean([avg_nmi_multi[sim][i] for sim in range(num_sims)])\
                               for i in range(len(res_params[sim]))],
             yerr=[np.std([avg_nmi_multi[sim][i] for sim in range(num_sims)])\
                                            for i in range(len(res_params[sim]))],
                 capsize=10)
plt.xscale('log')

plt.figure()
plt.errorbar(res_params[sim], [np.mean([num_clusts_multi[sim][i] for sim in range(num_sims)])\
                               for i in range(len(res_params[sim]))],
             yerr=[np.std([num_clusts_multi[sim][i] for sim in range(num_sims)])\
                                            for i in range(len(res_params[sim]))],
                 capsize=10)
plt.xscale('log')


#%%
# For every simulation, show the best membership found for the resolution
# parameter at index 6, reshaped to (slices, nodes) and transposed so that
# rows are nodes and columns are time slices.
# NOTE(review): the reshape assumes len(time_slices)-1 layers of
# nets[0].num_nodes nodes each -- confirm no empty slice was dropped.
for sim in range(num_sims):
    clust_matrix = np.array(best_clusts_multi[sim][6]).reshape((len(time_slices)-1,
                                                              nets[0].num_nodes)).T
    plt.figure()
    plt.imshow(clust_matrix)


#%% for one simulation

from collections import Counter

# Resolution parameter (index 5 of the first simulation) and simulation
# inspected in this cell.
reso = res_params[0][5]
sim = 5

# Cluster lists of every repetition for this simulation and resolution.
# `Partition` is the class imported in the multilayer summary cell.
clust_list = [Partition(num_nodes=(len(time_slices)-1)*nets[0].num_nodes,
                        node_to_cluster_dict=dict(enumerate(membership))).cluster_list
              for membership in clusters_multi[sim][reso]]

# Count how often each clustering was found; clusters and their members are
# sorted so that identical clusterings give identical keys.
countm = Counter(tuple(sorted(tuple(sorted(members)) for members in c_list))
                 for c_list in clust_list)
countm.most_common(10)


